from quant import Order
from quant.utils import InfluxClient, logging
from quant.const import OrderStatus, OrderType, UserEvent


class SpotAdaptor:  # Spot adaptor: one interface, spot-specific behavior, so strategies work for both spot and futures
    """Present spot balances through the same query/position/profit interface as FutureAdaptor.

    ``maker.variables.symbol`` is split on '/' into the base asset (``asset_0``)
    and the quote asset (``asset_1``).
    """

    def __init__(self, maker):
        self.maker = maker
        self.acc = maker.acc
        self.var = maker.variables
        base, quote = self.var.symbol.split('/')
        self.asset_0 = base
        self.asset_1 = quote
        self.profit_0 = None  # baseline valuation captured by the first priced profit() call

    def _total(self, asset):
        """Return the free + frozen balance of ``asset`` in the account."""
        entry = self.acc.balance[asset]
        return entry['free'] + entry['frozen']

    def query(self, wait=False):
        """Ask the API to refresh balances and open orders (spot has no margin/position query).

        If ``wait`` is true, block on ``api.join()`` afterwards.
        """
        client = self.acc.api
        client.query_balance()
        client.query_open_orders()
        if wait:
            client.join()

    def position(self):
        """Return the futures-style position: total base-asset balance minus the base holding ``hold``."""
        base_holding = self.var.hold
        return self._total(self.asset_0) - base_holding

    def profit(self):
        """Return the strategy P&L in quote units relative to the first valuation.

        Returns None when no mid price is available, and also on the first
        priced call, which only records the baseline in ``profit_0``.
        """
        variables = self.maker.variables
        mid = self.maker.markets.mids.get_default(variables.exchange, variables.symbol)
        if mid is None:
            return None

        # Subtract the base holding so only the strategy's own P&L (alpha) is counted,
        # not the price appreciation of the base holding (beta).
        valuation = (self._total(self.asset_0) - variables.hold) * mid + self._total(self.asset_1)
        if self.profit_0 is not None:
            return valuation - self.profit_0

        self.profit_0 = valuation
        return None


class FutureAdaptor:  # Futures adaptor: one interface, futures-specific behavior, so strategies work for both spot and futures
    """Present futures margin/position through the same interface as SpotAdaptor.

    Args:
        maker: strategy object providing ``acc`` (account) and ``variables``.
        margin_asset: key into ``acc.margin`` whose value is tracked by
            ``profit()``. The default ``'usdt'`` matches the original
            USDT-margined behavior; pass another key for other contract types.
    """

    def __init__(self, maker, margin_asset='usdt'):
        self.maker = maker
        self.acc = maker.acc
        self.var = maker.variables
        self.margin_asset = margin_asset  # margin key used as the account value in profit()
        self.profit_0 = None  # baseline margin captured by the first profit() call

    def query(self, wait=False):
        """Ask the API to refresh margin, positions and open orders.

        If ``wait`` is true, block on ``api.join()`` afterwards.
        """
        api = self.acc.api
        api.query_margin()
        api.query_position()
        api.query_open_orders()
        if wait:
            api.join()

    def position(self):
        """Return ``acc.position[symbol]['position']`` for the strategy symbol."""
        return self.acc.position[self.var.symbol]['position']

    def profit(self):
        """Return the change in ``acc.margin[margin_asset]`` since the first call.

        The first call records the baseline in ``profit_0`` and returns None.
        """
        margin = self.acc.margin[self.margin_asset]
        if self.profit_0 is None:
            self.profit_0 = margin
            return None
        return margin - self.profit_0


class Executor:  # Strategy executor: drives a single order for one side and one offset
    """Keep at most one order for a fixed side/offset tracking a target price and amount.

    The strategy only supplies the desired price and amount. The executor rounds
    them to the configured precisions, then places, cancels or keeps its order
    (post-only), filtering small price moves.

    ``config`` keys read: 'symbol', 'var' (price tolerance before re-quoting),
    'prc_precision', 'amt_precision', 'minimum' (smallest amount worth quoting).
    """

    def __init__(self, maker, config, side, offset):
        self.maker = maker                          # strategy object (provides .acc)
        self.side = side                            # buy/sell direction
        self.offset = offset                        # open/close direction
        self.symbol = config['symbol']              # trading symbol
        self.var = config['var']                    # price-move tolerance (volatility filter)
        self.prc_decimal = config['prc_precision']  # price decimal places
        self.amt_decimal = config['amt_precision']  # amount decimal places
        self.minimum = config['minimum']            # minimum order amount
        self._cid = None                            # client_order_id of the order this executor is tracking

    def execute(self, price, amount):
        """Reconcile the tracked order with the target ``price``/``amount``.

        - amount below ``minimum``: this side is paused and any tracked order is canceled;
        - no tracked order in ``acc.orders``: place a PostOnly order and remember its id;
        - tracked order's price differs by more than ``var``: cancel it.
        Otherwise the existing order is left untouched.
        """
        target_price = round(price, self.prc_decimal)
        target_amount = round(amount, self.amt_decimal)

        account = self.maker.acc
        live = account.orders.get(self._cid)

        if target_amount < self.minimum:
            # Too small to quote: pause this side by pulling any live order.
            if live is not None:
                self.cancel(live)
            return

        if live is None:
            # Nothing tracked: quote a fresh order and remember its local id.
            fresh = Order(
                symbol=self.symbol,
                side=self.side,
                price=target_price,
                amount=target_amount,
                offset=self.offset,
                order_type=OrderType.PostOnly,
            )
            self._cid = account.api.place_order(fresh)
            return

        if abs(live.price - target_price) > self.var:
            # Quote drifted beyond tolerance: cancel it; a new one is placed on a
            # later call once this order no longer appears in acc.orders.
            self.cancel(live)

    def cancel(self, order):
        """Request cancellation of ``order`` unless it is still Placing or already Canceling."""
        if order.status not in (OrderStatus.Placing, OrderStatus.Canceling):
            self.maker.acc.api.cancel_order(order)


class InfoWriter:  # Writes trading data to the database
    """Write strategy telemetry (prices, orders, fills, position, P&L, logs) via an InfluxClient.

    Every point goes to database ``maker.variables.database_name``, is tagged
    with ``log_id`` = ``maker.variables.log_id`` and is timestamped with
    ``maker.markets.now``. These values are read at write time, not at
    registration time. Nothing is written until a ``start_write_*`` method
    registers a timer callback (on ``maker.markets.timer``) or an event
    subscription.
    """

    def __init__(self, maker, client: InfluxClient):
        self.maker = maker
        self.client = client
        self.markets = maker.markets
        self.var = maker.variables
        self.timer = maker.markets.timer

    def _write_point(self, measurement, fields, tags=None):
        """Write one point with the common database, ``log_id`` tag and timestamp.

        Args:
            measurement: measurement name.
            fields: field dict for the point.
            tags: optional extra tags, added after ``log_id``.
        """
        point_tags = {'log_id': self.var.log_id}
        if tags:
            point_tags.update(tags)
        self.client.write(
            database=self.var.database_name,
            measurement=measurement,
            tags=point_tags,
            fields=fields,
            now=self.markets.now,
        )

    def start_write_pricing(self, name, calculator, interval=0.1):
        """Periodically write ``calculator.get()`` as field ``name`` of 'pricing' (skipped when None)."""
        def on_timer():
            value = calculator.get()
            if value is None:
                return
            self._write_point('pricing', {name: float(value)})
        self.timer.register(on_timer, interval)

    def start_write_mid(self, exchange, symbol, interval=0.1):
        """Periodically write the mid price of ``exchange``/``symbol`` to 'mid' (skipped when None)."""
        def on_timer():
            mid = self.markets.mids.get_default(exchange, symbol)
            if mid is None:
                return
            self._write_point('mid', {'mid': float(mid)})
        self.timer.register(on_timer, interval)

    def start_write_orders(self, interval=1):
        """Periodically write one 'orders' point per entry in ``acc.orders`` (tagged side/offset)."""
        def on_timer():
            orders = self.maker.acc.orders
            for order in orders.values():
                self._write_point(
                    'orders',
                    {
                        'price': float(order.price),
                        'amount': float(order.amount),
                        'status': order.status,
                    },
                    tags={'side': order.side, 'offset': order.offset},
                )
        self.timer.register(on_timer, interval)

    def start_write_execution(self):
        """Write an 'execution' point for each place/cancel event from the account's info engine."""
        def _write(operation, order):
            self._write_point(
                'execution',
                {
                    'price': float(order.price),
                    'amount': float(order.amount),
                },
                tags={
                    'side': order.side,
                    'offset': order.offset,
                    'operation': operation,
                },
            )

        def on_place(order):
            _write('place', order)

        def on_cancel(order):
            _write('cancel', order)

        info_engine = self.maker.acc.info_engine
        info_engine.subscribe(info_engine.Place, on_place)
        info_engine.subscribe(info_engine.Cancel, on_cancel)

    def start_write_deal(self):
        """Write a 'deal' point for each UserEvent.Deal; ``user_data.data`` is unpacked as (order, amount).

        The price written is the order's price; the amount is the one carried by the event.
        """
        def on_deal(user_data):
            order, amount = user_data.data
            self._write_point(
                'deal',
                {
                    'price': float(order.price),
                    'amount': float(amount),
                },
                tags={'side': order.side, 'offset': order.offset},
            )
        self.maker.acc.subscribe(UserEvent.Deal, on_deal)

    def start_write_position(self, interval=0.1):
        """Periodically write ``maker.adaptor.position()`` to 'position'.

        A None position is skipped, like the other value writers, instead of
        raising TypeError from float(None) inside the timer callback.
        """
        def on_timer():
            pos = self.maker.adaptor.position()
            if pos is None:
                return
            self._write_point('position', {'pos': float(pos)})
        self.timer.register(on_timer, interval)

    def start_write_profit(self, interval=1):
        """Periodically write ``maker.adaptor.profit()`` to 'profit' (skipped when None)."""
        def on_timer():
            profit = self.maker.adaptor.profit()
            if profit is None:
                return
            self._write_point('profit', {'profit': float(profit)})
        self.timer.register(on_timer, interval)

    def start_write_asset(self, interval=1):
        """Periodically write ``acc.evaluate()`` to 'asset'.

        A None valuation is skipped, like the other value writers, instead of
        raising TypeError from float(None) inside the timer callback.
        """
        def on_timer():
            asset = self.maker.acc.evaluate()
            if asset is None:
                return
            self._write_point('asset', {'asset': float(asset)})
        self.timer.register(on_timer, interval)

    def start_write_logging(self):
        """Subscribe to quant's logging at INFO/WARN/ERROR/CRITICAL and write each message to 'logging'."""
        def make_handler(level_name):
            # One closure per level; level_name is bound per call, so there is no late-binding issue.
            def on_log(msg):
                self._write_point('logging', {'level': level_name, 'msg': msg})
            return on_log

        logging.subscribe(make_handler('info'), logging.INFO)
        logging.subscribe(make_handler('warning'), logging.WARN)
        logging.subscribe(make_handler('error'), logging.ERROR)
        logging.subscribe(make_handler('critical'), logging.CRITICAL)

    def start_write_api_error(self):
        """Write request and operation failures reported by the info engine to 'api_error'."""
        def _write(msg):
            self._write_point('api_error', {'msg': msg})

        def on_request_fail(endpoint, model, response):
            msg = 'request fail {} {}({})'.format(endpoint, response.text, model)
            _write(msg)

        def on_execution_fail(endpoint, model, response):
            msg = 'execution fail {} {}({})'.format(endpoint, response.text, model)
            _write(msg)

        info_engine = self.maker.acc.info_engine
        info_engine.subscribe(info_engine.RequestFail, on_request_fail)
        info_engine.subscribe(info_engine.OperationFail, on_execution_fail)


def create_adaptor(maker_type, maker):
    """Build the adaptor matching ``maker_type``.

    Args:
        maker_type: 'spot' for SpotAdaptor or 'future' for FutureAdaptor.
        maker: strategy object passed through to the adaptor constructor.

    Raises:
        NotImplementedError: for any other ``maker_type``.
    """
    if maker_type == 'future':
        return FutureAdaptor(maker)
    elif maker_type == 'spot':
        return SpotAdaptor(maker)
    raise NotImplementedError('No maker_type named {}'.format(maker_type))
